import random

import torch
import torch.nn as nn
from models_interhuman_selfattn.vq.encdec import Encoder, Decoder
from models_interhuman_selfattn.vq.residual_vq import ResidualVQ, MultiVQ
from data.utils import fid_l, fid_r
from utils.paramUtil import t2m_edge_list
from models_interhuman_selfattn.vq.InteracteModel import *

class RVQVAE(nn.Module):
    def __init__(self,
                 args,
                 conv_dim=1,
                 input_width=263,
                 nb_code=1024,
                 code_dim=512,
                 output_emb_width=512,
                 down_t=3,
                 stride_t=2,
                 width=512,
                 depth=3,
                 dilation_growth_rate=3,
                 activation='relu',
                 norm=None):

        super().__init__()
        assert output_emb_width == code_dim
        self.args = args
        self.code_dim = code_dim
        self.num_code = nb_code
        self.conv_dim = conv_dim
        self.joints_num = args.joints_num
        self.dataset_name = args.dataset_name

        if self.dataset_name == "interhuman":
            filter_s = None
            stride_s = None
            spatial_upsample = (2.2, 2)
        elif self.dataset_name == "interx":
            filter_s = 6
            stride_s = 3
            spatial_upsample = (3.5, 3.3) #3.25

        self.encoder = Encoder(args, conv_dim, input_width, output_emb_width, down_t, stride_t, width, depth,
                               dilation_growth_rate, activation=activation, norm=norm, filter_s=filter_s, stride_s=stride_s)
        self.decoder = Decoder(args, conv_dim, input_width, output_emb_width, down_t, stride_t, width, depth,
                               dilation_growth_rate, activation=activation, norm=norm, spatial_upsample=spatial_upsample)
        if self.args.physical_decoder:
            self.physical_decoder= Decoder(args, conv_dim, input_width*2, output_emb_width*2, down_t, stride_t, width, depth,
                                dilation_growth_rate, activation=activation, norm=norm, spatial_upsample=spatial_upsample)
        if self.args.gate_select:
            self.alpha = nn.Parameter(torch.tensor(1.0, dtype=torch.float))
            self.beta = nn.Parameter(torch.tensor(1.0, dtype=torch.float))
        rvqvae_config = {
            'num_quantizers': args.num_quantizers,
            'shared_codebook': args.shared_codebook,
            'nb_code': nb_code,
            'code_dim':code_dim, 
            'args': args,
            'quantize_dropout_prob': args.quantize_dropout_prob,
            'quantize_dropout_cutoff_index': 0
        }
        
        self.quantizer = ResidualVQ(**rvqvae_config)

        physical_rvqvae_config = {
            'num_quantizers': args.num_quantizers,
            'shared_codebook': args.shared_codebook,
            'nb_code': nb_code,
            'code_dim':code_dim*2, 
            'args': args,
            'quantize_dropout_prob': args.quantize_dropout_prob,
            'quantize_dropout_cutoff_index': 0
        }
        
        self.physical_quantizer = ResidualVQ(**physical_rvqvae_config)
        if args.Interact_select=="Origin":
            self.Interaction= CrossInteracteMoudle(output_emb_width,num_heads=8)
        elif args.Interact_select=="middle":            
            self.Interaction= MiddleCrossInteracteMoudle(output_emb_width,num_heads=8)
        elif args.Interact_select=="LORA":
            self.Interaction= LORACrossInteracteMoudle(output_emb_width,num_heads=8)

        # if args.vq_select=="RSA_Code":
        #     self.Interaction= MiddleCrossInteracteMoudle(output_emb_width,num_heads=8)
        # else:
        #     self.Interaction= CrossInteracteMoudle(output_emb_width,num_heads=8)

    def preprocess(self, x):
        """Rearrange a motion tensor into the layout the encoder consumes.

        * ``conv_dim == 1``: ``(B, T, F)`` is permuted to ``(B, F, T)``.
        * ``conv_dim == 2``: the result is ``(B, D, J, T)``. For the
          "interhuman" dataset the flat feature axis is first unpacked into
          per-joint positions (3), velocities (3) and 6-D rotations; the flat
          layout carries rotations for only ``joints_num - 1`` joints, so a
          zero row is prepended for the first joint. Other datasets are
          assumed to already be ``(B, T, J, D)`` — TODO confirm with caller.

        Args:
            x: Input motion tensor.

        Returns:
            A float32 tensor in the layout described above.

        Raises:
            ValueError: If ``self.conv_dim`` is neither 1 nor 2 (this used to
                surface as an ``UnboundLocalError``).
        """
        if self.conv_dim == 1:
            # (bs, T, Jx3) -> (bs, Jx3, T)
            return x.permute(0, 2, 1).float()

        if self.conv_dim != 2:
            raise ValueError(f"Unsupported conv_dim: {self.conv_dim!r} (expected 1 or 2)")

        if self.dataset_name == "interhuman":
            batch, frames = x.shape[0], x.shape[1]
            n_joints = self.joints_num
            pos_end = n_joints * 3
            vel_end = pos_end * 2
            rot_end = vel_end + (n_joints - 1) * 6

            pos = x[..., :pos_end].reshape(batch, frames, n_joints, 3)
            vel = x[..., pos_end:vel_end].reshape(batch, frames, n_joints, 3)
            rot = x[..., vel_end:rot_end].reshape(batch, frames, n_joints - 1, 6)

            # Allocate the padding directly on the input's device instead of
            # creating it on the CPU and copying it over.
            first_joint_rot = torch.zeros(batch, frames, 1, 6, device=x.device)
            rot = torch.cat([first_joint_rot, rot], dim=2)

            joints = torch.cat([pos, vel, rot], dim=-1)
        else:
            joints = x

        # (B, T, J, D) -> (B, D, J, T)
        return joints.permute(0, 3, 2, 1).float()

    def postprocess(self, x):
        """Convert decoder output back to a per-frame flat feature layout.

        * ``conv_dim == 1``: ``(B, F, T)`` is permuted to ``(B, T, F)``.
        * ``conv_dim == 2``: ``(B, D, J, T)`` is permuted to ``(B, T, J, D)``.
          For the "interhuman" dataset the per-joint features are then
          flattened into positions, velocities and rotations (rotations of
          the first joint are dropped), followed by four zero columns.
        * Any other ``conv_dim``: ``x`` is returned untouched.
        """
        if self.conv_dim == 1:
            # (bs, Jx3, T) -> (bs, T, Jx3)
            return x.permute(0, 2, 1).float()

        if self.conv_dim != 2:
            return x

        per_joint = x.permute(0, 3, 2, 1).float()
        if self.dataset_name != "interhuman":
            return per_joint

        batch, frames = per_joint.shape[0], per_joint.shape[1]
        flat_shape = [batch, frames, -1]

        positions = per_joint[..., :3].reshape(flat_shape)
        velocities = per_joint[..., 3:6].reshape(flat_shape)
        # Skip joint 0: its rotation slot is not part of the flat layout.
        rotations = per_joint[:, :, 1:, 6:12].reshape(flat_shape)
        # Four trailing zero columns appended after the rotations.
        trailing = torch.zeros((batch, frames, 4), device=per_joint.device)

        return torch.cat([positions, velocities, rotations, trailing], dim=-1)

    def encode(self, x):
        """Encode a single motion into residual-VQ code indices.

        The input goes through ``preprocess`` and the encoder; a latent with
        more than three dimensions is flattened to ``(B, C, -1)`` before
        being passed to ``self.quantizer.quantize(..., return_latent=True)``.

        Returns:
            The ``(code_idx, all_codes)`` pair produced by the quantizer.
        """
        latent = self.encoder(self.preprocess(x))

        # The quantizer expects a 3-D tensor; collapse trailing axes if needed.
        if latent.dim() != 3:
            latent = latent.reshape(latent.shape[0], latent.shape[1], -1)

        code_idx, all_codes = self.quantizer.quantize(latent, return_latent=True)
        return code_idx, all_codes

    def forward(self, x, x2=None, verbose=False):
        """Encode, quantize, fuse and decode the motions of two persons.

        Both motions go through the shared encoder and per-person quantizer.
        The quantized latents are then fused by the method selected with
        ``args.vq_select`` (``RSA_CA_Code``, ``RSA_Code`` or ``NoAttnCode``).

        * ``args.stage_select == "one"``: the fused latents are decoded with
          ``self.decoder``; when ``args.physical_decoder`` is set, the
          physical features are additionally decoded by
          ``self.physical_decoder``.
        * ``args.stage_select == "two"``: the un-fused per-person latents are
          decoded with ``self.decoder`` (single-person reconstruction), and
          the concatenated fused latents are decoded with
          ``self.physical_decoder`` (two-person reconstruction).

        Args:
            x: Motion of the first person.
            x2: Motion of the second person (required despite the default).
            verbose: Unused; kept for interface compatibility.

        Returns:
            A 4-tuple of two-element lists:
            ``[x_out1, x_out2]``, ``[physical_out1, physical_out2]`` (both
            ``None`` in stage "one" without a physical decoder),
            ``[mean per-person commit loss, physical commit loss]`` and
            ``[mean per-person perplexity, physical perplexity]``.

        Raises:
            ValueError: If ``x2`` is missing, or ``args.vq_select`` /
                ``args.stage_select`` holds an unsupported value.
        """
        if x2 is None:
            raise ValueError("forward() requires the second person's motion `x2`")

        fusion_methods = {
            "RSA_CA_Code": self.RSA_CA_Code,
            "RSA_Code": self.RSA_Code,
            "NoAttnCode": self.NoAttnCode,
        }
        vq_select = self.args.vq_select
        if vq_select not in fusion_methods:
            raise ValueError(
                f"Unsupported vq_select: {vq_select!r} (expected one of {sorted(fusion_methods)})"
            )
        stage = self.args.stage_select
        if stage not in ("one", "two"):
            raise ValueError(f"Unsupported stage_select: {stage!r} (expected 'one' or 'two')")

        # Encode
        x_in1 = self.preprocess(x)  # B, D=12, J=22, T || B, J=22xD=12, T
        x_in2 = self.preprocess(x2)  # B, D=12, J=22, T || B, J=22xD=12, T
        x_encoder1 = self.encoder(x_in1)  # B, D=512, 5, T/2 || B, J=7xD=512, T//4
        x_encoder2 = self.encoder(x_in2)  # B, D=512, 5, T/2 || B, J=7xD=512, T//4

        # Quantize each person independently. Each latent is flattened using
        # its OWN shape (the original checked person 1's rank for person 2).
        x_quantized1, commit_loss1, perplexity1 = self._quantize_latent(x_encoder1)
        x_quantized2, commit_loss2, perplexity2 = self._quantize_latent(x_encoder2)

        # Interaction between the two persons yields the fused latents and the
        # jointly quantized "physical" features.
        fused1, fused2, physical_features, commit_loss3, perplexity3 = fusion_methods[vq_select](
            x_quantized1, x_quantized2)

        # Decode
        if stage == "one":
            x_out1 = self.decoder(fused1)  # B, D=12, J=22, T || B, J=22xD=12, T
            x_out2 = self.decoder(fused2)  # B, D=12, J=22, T || B, J=22xD=12, T
            x_out1 = self.postprocess(x_out1)  # B,T,D=262
            x_out2 = self.postprocess(x_out2)  # B,T,D=262
            if self.args.physical_decoder:
                physical_out1, physical_out2 = self._decode_physical_pair(physical_features)
            else:
                physical_out1, physical_out2 = None, None
        else:
            # Stage "two", part 1: single-person reconstruction from the
            # un-fused quantized latents.
            x_out1 = self.decoder(x_quantized1)  # B, D=12, J=22, T || B, J=22xD=12, T
            x_out2 = self.decoder(x_quantized2)  # B, D=12, J=22, T || B, J=22xD=12, T
            x_out1 = self.postprocess(x_out1)  # B,T,D=262
            x_out2 = self.postprocess(x_out2)  # B,T,D=262
            # Stage "two", part 2: two-person reconstruction from the fused
            # latents.
            physical_out1, physical_out2 = self._decode_physical_pair(
                torch.cat([fused1, fused2], dim=1))

        return ([x_out1, x_out2],
                [physical_out1, physical_out2],
                [(commit_loss1 + commit_loss2) / 2, commit_loss3],
                [(perplexity1 + perplexity2) / 2, perplexity3])

    def _quantize_latent(self, latent):
        """Quantize one encoder latent and restore its original shape."""
        latent_shape = latent.shape
        if len(latent_shape) != 3:
            latent = latent.reshape(latent_shape[0], latent_shape[1], -1)
        quantized, _, commit_loss, perplexity = self.quantizer(latent, sample_codebook_temp=0.5)
        return quantized.reshape(latent_shape), commit_loss, perplexity

    def _decode_physical_pair(self, joint_latent):
        """Decode a two-person latent and split the result into per-person motions."""
        physical_out = self.physical_decoder(joint_latent)
        half = physical_out.shape[1] // 2
        first, second = torch.split(physical_out, [half, half], dim=1)
        return self.postprocess(first), self.postprocess(second)

    def forward_decoder(self, x, x2=None, soft_lookup=False):
        """Decode two persons' motions from code indices (or code probabilities).

        The codes of each person are looked up in ``self.quantizer``, summed
        over the first axis (presumably the residual-quantizer axis — TODO
        confirm), fused according to ``args.vq_select`` and decoded according
        to ``args.stage_select``.

        Args:
            x: Code indices of the first person, or probabilities when
                ``soft_lookup`` is true.
            x2: Same for the second person (required despite the default).
            soft_lookup: Use ``get_soft_codes_from_probs`` instead of
                ``get_codes_from_indices`` for both persons.

        Returns:
            ``(x_out1, x_out2)``; both are ``None`` when
            ``args.stage_select`` is neither ``"one"`` nor ``"two"``.

        Raises:
            ValueError: If ``x2`` is missing.
        """
        if x2 is None:
            raise ValueError("forward_decoder() requires the second person's codes `x2`")

        # The soft path previously stored its result in an unused variable and
        # then crashed on the undefined `x1`; both persons now share one lookup.
        if soft_lookup:
            lookup = self.quantizer.get_soft_codes_from_probs
        else:
            lookup = self.quantizer.get_codes_from_indices
        x1 = lookup(x).sum(dim=0).permute(0, 2, 1)  # B,T,D=512 -> B,D,T
        x2 = lookup(x2).sum(dim=0).permute(0, 2, 1)  # B,T,D=512 -> B,D,T

        if self.conv_dim == 2:
            # NOTE(review): the spatial size 5 is hard-coded here — confirm it
            # matches the encoder output for every supported dataset.
            x1 = x1.reshape(x1.shape[0], x1.shape[1], 5, x1.shape[2] // 5)  # B,D,T -> B,D,5,T/5
            x2 = x2.reshape(x2.shape[0], x2.shape[1], 5, x2.shape[2] // 5)  # B,D,T -> B,D,5,T/5

        # Fuse the two latents; an unrecognised vq_select leaves them as-is.
        x_quantized1, x_quantized2 = x1, x2
        vq_select = self.args.vq_select
        if vq_select == "RSA_CA_Code":
            x_quantized1, x_quantized2, _, _, _ = self.RSA_CA_Code(x_quantized1, x_quantized2)
        elif vq_select == "RSA_Code":
            x_quantized1, x_quantized2, _, _, _ = self.RSA_Code(x_quantized1, x_quantized2)
        elif vq_select == "NoAttnCode":
            x_quantized1, x_quantized2, _, _, _ = self.NoAttnCode(x_quantized1, x_quantized2)

        # Decode
        x_out1, x_out2 = None, None
        stage = self.args.stage_select
        if stage == "one":
            x_out1 = self.decoder(x_quantized1)  # B, D=12, J=22, T || B, J=22xD=12, T
            x_out2 = self.decoder(x_quantized2)  # B, D=12, J=22, T || B, J=22xD=12, T
            x_out1, x_out2 = self.postprocess(x_out1), self.postprocess(x_out2)
        elif stage == "two":
            # Concatenate BOTH persons' latents (the original passed person 1
            # twice, so person 2 was decoded from person 1's latent).
            physical_out = self.physical_decoder(torch.cat([x_quantized1, x_quantized2], dim=1))
            half = physical_out.shape[1] // 2
            physical_out1, physical_out2 = torch.split(physical_out, [half, half], dim=1)
            x_out1 = self.postprocess(physical_out1)
            x_out2 = self.postprocess(physical_out2)
        return x_out1, x_out2
    



    def RSA_CA_Code(self, x_quantized1, x_quantized2):
        """Fuse two quantized latents using the interaction module's outputs.

        The first two outputs of ``self.Interaction`` are concatenated along
        dim 1, quantized by ``self.physical_quantizer`` and split back into
        one part per person. Each part is combined with the matching
        interaction output: a learned ``alpha`` / ``beta`` blend when
        ``args.gate_select`` is set, a plain sum otherwise.

        Returns:
            ``(fused1, fused2, physical_features, commit_loss3, perplexity3)``
        """
        attended1, attended2, _, _ = self.Interaction(x_quantized1, x_quantized2)

        joint = torch.cat([attended1, attended2], dim=1)
        joint_shape = joint.shape
        # The quantizer takes a 3-D tensor; collapse trailing axes if needed.
        if len(joint_shape) != 3:
            joint = joint.reshape(joint_shape[0], joint_shape[1], -1)
        joint, _, commit_loss3, perplexity3 = self.physical_quantizer(joint, sample_codebook_temp=0.5)
        joint = joint.reshape(joint_shape)

        channels = [attended1.shape[1], attended2.shape[1]]
        physical1, physical2 = torch.split(joint, channels, dim=1)

        if self.args.gate_select:
            fused1 = self.alpha * attended1 + physical1 * (1 - self.alpha)
            fused2 = self.beta * attended2 + physical2 * (1 - self.beta)
        else:
            fused1 = attended1 + physical1
            fused2 = attended2 + physical2
        return fused1, fused2, joint, commit_loss3, perplexity3
    def RSA_Code(self, x_quantized1, x_quantized2):
        """Fuse two quantized latents using all four interaction outputs.

        ``self.Interaction`` returns four tensors. The last two are
        concatenated along dim 1, quantized by ``self.physical_quantizer``
        and split back per person; the first two serve as the base each
        physical part is combined with (a learned ``alpha`` / ``beta`` blend
        when ``args.gate_select`` is set, a plain sum otherwise).

        Returns:
            ``(fused1, fused2, physical_features, commit_loss3, perplexity3)``
        """
        base1, base2, cross1, cross2 = self.Interaction(x_quantized1, x_quantized2)

        joint = torch.cat([cross1, cross2], dim=1)
        joint_shape = joint.shape
        # The quantizer takes a 3-D tensor; collapse trailing axes if needed.
        if len(joint_shape) != 3:
            joint = joint.reshape(joint_shape[0], joint_shape[1], -1)
        joint, _, commit_loss3, perplexity3 = self.physical_quantizer(joint, sample_codebook_temp=0.5)
        joint = joint.reshape(joint_shape)

        channels = [cross1.shape[1], cross2.shape[1]]
        physical1, physical2 = torch.split(joint, channels, dim=1)

        if self.args.gate_select:
            fused1 = self.alpha * base1 + physical1 * (1 - self.alpha)
            fused2 = self.beta * base2 + physical2 * (1 - self.beta)
        else:
            fused1 = base1 + physical1
            fused2 = base2 + physical2
        return fused1, fused2, joint, commit_loss3, perplexity3
    def NoAttnCode(self, x_quantized1, x_quantized2):
        """Fuse two quantized latents, keeping the raw inputs as the base.

        The first two outputs of ``self.Interaction`` are concatenated along
        dim 1, quantized by ``self.physical_quantizer`` and split back per
        person. Unlike the sibling methods, each physical part is combined
        with the ORIGINAL input latent rather than an interaction output (a
        learned ``alpha`` / ``beta`` blend when ``args.gate_select`` is set,
        a plain sum otherwise).

        Returns:
            ``(fused1, fused2, physical_features, commit_loss3, perplexity3)``
        """
        attended1, attended2, _, _ = self.Interaction(x_quantized1, x_quantized2)

        joint = torch.cat([attended1, attended2], dim=1)
        joint_shape = joint.shape
        # The quantizer takes a 3-D tensor; collapse trailing axes if needed.
        if len(joint_shape) != 3:
            joint = joint.reshape(joint_shape[0], joint_shape[1], -1)
        joint, _, commit_loss3, perplexity3 = self.physical_quantizer(joint, sample_codebook_temp=0.5)
        joint = joint.reshape(joint_shape)

        channels = [attended1.shape[1], attended2.shape[1]]
        physical1, physical2 = torch.split(joint, channels, dim=1)

        if self.args.gate_select:
            fused1 = self.alpha * x_quantized1 + physical1 * (1 - self.alpha)
            fused2 = self.beta * x_quantized2 + physical2 * (1 - self.beta)
        else:
            fused1 = x_quantized1 + physical1
            fused2 = x_quantized2 + physical2
        return fused1, fused2, joint, commit_loss3, perplexity3